# Written by Zachary Frankel December 2009
# For Biophysics 101 at Harvard College
# This file contains functions for fitting a logistic regression
# (least squares on logit-transformed probabilities) to imported data

import csv
import sys
import math
import ols
import dataread
from numpy import *
from numpy.random import randn


def str_to_arr(phenotype):
    """Convert a genotype key string into a float array, one entry per character.

    Despite the parameter name, the caller in this file (load_log_reg) passes
    a genotype key of ``dataread.datalist[phenotype]`` (e.g. ``'011'``) to
    obtain one predictor row for the regression.

    Args:
        phenotype: string whose characters are each parsed with ``int``.

    Returns:
        1-D float64 numpy array, e.g. ``'011'`` -> ``array([0., 1., 1.])``;
        an empty string yields an empty float64 array.

    Raises:
        ValueError: if a character cannot be parsed by ``int``.
    """
    # Collect into a list and convert once: calling numpy.append for every
    # character copies the whole array each time (quadratic).
    return array([int(digit) for digit in phenotype], dtype=float)


def load_log_reg(phenotype):
    """Fit logit(P(phenotype | genotype)) against the genotype digits.

    For every genotype key in ``dataread.datalist[phenotype]``, the value
    returned by ``dataread.cond_prob(phenotype, key)`` is converted to its
    logit, log(p) - log(1 - p), and used as the response. The key's digits
    (via ``str_to_arr``) form the matching predictor row. All rows are
    fitted with ``ols.ols``, using one predictor name per SNP
    (``x0 .. x{snpnum-1}``).

    Args:
        phenotype: key into ``dataread.datalist`` selecting the phenotype.

    Returns:
        The ``p`` attribute of the fitted ``ols.ols`` model.

    Raises:
        ValueError: if the phenotype has no genotype combinations, or if a
            conditional probability is not strictly between 0 and 1 (its
            logit would be undefined).
    """
    rows = []
    logits = []
    for combo in dataread.datalist[phenotype]:
        prob = dataread.cond_prob(phenotype, combo)

        # log(p) - log(1 - p) is undefined at p = 0 and p = 1; report which
        # genotype caused it instead of a bare "math domain error".
        if not 0 < prob < 1:
            raise ValueError(
                "conditional probability %r for genotype %r of phenotype %r "
                "is not strictly between 0 and 1" % (prob, combo, phenotype))

        logits.append(math.log(prob) - math.log(1 - prob))
        rows.append(str_to_arr(combo))

    if not rows:
        raise ValueError("no genotype data for phenotype %r" % (phenotype,))

    # Convert once at the end. This also guarantees ndarray inputs even when
    # there is a single combination; the old append-based code left plain
    # Python lists in that case.
    xarr = array(rows)
    yarr = array(logits)

    # One predictor name per SNP: 'x0', 'x1', ...
    xlist = ['x' + str(i) for i in range(dataread.snpnum)]

    mymodel = ols.ols(yarr, xarr, 'yarr', xlist)
    # NOTE(review): if ols is the SciPy Cookbook OLS class, `.p` holds the
    # coefficient p-values and `.b` the coefficients. The script below feeds
    # this result into a logistic risk formula, which looks like it wants
    # coefficients. Confirm which attribute is intended.
    return mymodel.p
    

          
        
# Command-line entry point: load the data file named by the first argument,
# fit the model for phenotype 0, print the fitted values, then print a risk
# score for one hard-coded example vector.

# sys.argv[0] is always the script name, so a file name requires at least two
# entries (the old `< 1` check could never trigger and led to an IndexError).
if len(sys.argv) < 2:
    sys.stderr.write("please specify a file name\n")
    sys.exit(1)

# Read the data file name from the command line; dataread.load_data does the
# actual loading.
filename = sys.argv[1]

# NOTE(review): the return value is unused here; presumably load_data fills
# the dataread state (datalist, snpnum) that load_log_reg reads. Confirm this,
# and document the meaning of the second argument (2).
loaded_data = dataread.load_data(filename, 2)

pvals = load_log_reg(0)
print(pvals)

# Example vector scored against the fitted values. numpy.inner requires it to
# have the same length as pvals.
# NOTE(review): presumably a leading 1 for the intercept followed by three
# SNP genotypes. Confirm.
a = array([1, 0, 1, 1])
val = -1 * inner(a, pvals)
risk = 1 / (1 + math.exp(val))
print(risk)



